{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "08bf4b89015df3fa62ec3e5cebfe3c3e5a181176f7e37ebd57af46c3a62ed99e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "# NIN块"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def nin_block(in_channels, out_channels, kernel_size, stride, padding):\n",
    "    blk = nn.Sequential(\n",
    "        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(out_channels, out_channels, kernel_size=1),\n",
    "        nn.ReLU(),\n",
    "        nn.Conv2d(out_channels, out_channels, kernel_size=1),\n",
    "        nn.ReLU()\n",
    "    )\n",
    "\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "0 output shape :  torch.Size([1, 96, 54, 54])\n1 output shape :  torch.Size([1, 96, 26, 26])\n2 output shape :  torch.Size([1, 256, 26, 26])\n3 output shape :  torch.Size([1, 256, 12, 12])\n4 output shape :  torch.Size([1, 384, 12, 12])\n5 output shape :  torch.Size([1, 384, 5, 5])\n6 output shape :  torch.Size([1, 384, 5, 5])\n7 output shape :  torch.Size([1, 10, 5, 5])\n8 output shape :  torch.Size([1, 10, 1, 1])\n9 output shape :  torch.Size([1, 10])\n"
     ]
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    nin_block(1, 96, kernel_size=11, stride=4, padding=0),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nin_block(96, 256, kernel_size=5, stride=1, padding=2),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nin_block(256, 384, kernel_size=3, stride=1, padding=1),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2),\n",
    "    nn.Dropout(0.5),\n",
    "    nin_block(384, 10, kernel_size=3, stride=1, padding=1),\n",
    "    d2l.GlobalAvgPool2d(),\n",
    "    d2l.FlattenLayer()\n",
    ")\n",
    "\n",
    "X = torch.rand(1, 1, 224, 224)\n",
    "for name, blk in net.named_children():\n",
    "    X = blk(X)\n",
    "    print(name, 'output shape : ', X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 1.7571, train acc 0.353, test acc 0.694, time 138.2 sec\n",
      "epoch 2, loss 0.7112, train acc 0.742, test acc 0.778, time 137.4 sec\n",
      "epoch 3, loss 0.5249, train acc 0.804, test acc 0.815, time 137.1 sec\n",
      "epoch 4, loss 0.4569, train acc 0.830, test acc 0.838, time 137.4 sec\n",
      "epoch 5, loss 0.4173, train acc 0.844, test acc 0.847, time 138.8 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "\n",
    "lr, num_epochs = 0.002, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}